package com.wlf.server.common.ws;

import cn.hutool.core.map.BiMap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.concurrent.ConcurrentHashMap;

/**
 * 这个类用来处理用户和连接的关联关系
 */
@Slf4j
@Component
public class NioWebSocketChannelPool {
  /**
   * All registered client connections. A {@link DefaultChannelGroup} drops a channel by itself
   * once that channel is closed, so a {@link ChannelId} remembered elsewhere may no longer
   * resolve to a channel here.
   */
  private final DefaultChannelGroup channels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

  /**
   * Bidirectional binding between a user id and the id of the connection that user is bound to.
   *
   * <p>Compound updates made through this class (lookup + remove, remove + put) are serialized on
   * this instance's monitor so concurrent bind/unbind calls from different I/O threads cannot
   * interleave. The getter exposes the live map; callers that mutate it directly bypass that
   * guard.
   */
  @Getter
  private final BiMap<String, ChannelId> bindUserMap = new BiMap<>(new ConcurrentHashMap<>());

  /**
   * Registers a new client connection.
   *
   * @param channel the connection to register
   */
  public void addChannel(Channel channel) {
    channels.add(channel);
  }

  /**
   * Unregisters a client connection and drops the user binding that points at it, if any.
   *
   * <p>The binding is removed only while it still refers to this connection: if the user has
   * already been bound to a newer connection (e.g. after a reconnect), that newer binding is
   * kept.
   *
   * @param channel the connection being removed
   */
  public synchronized void removeChannel(Channel channel) {
    ChannelId channelId = channel.id();
    String userId = bindUserMap.getKey(channelId);
    if (userId != null) {
      unbindIfCurrent(userId, channelId);
    }
    channels.remove(channel);
  }

  /**
   * Binds a user to a connection, replacing any earlier binding of that user and any other user
   * previously bound to the same connection.
   *
   * @param userId  the user id; must not be {@code null} (the backing map rejects null keys)
   * @param channel the connection to bind the user to
   */
  public synchronized void bindUser(String userId, Channel channel) {
    ChannelId channelId = channel.id();
    // A connection belongs to at most one user: detach whoever held it before, otherwise pushes
    // addressed to that previous user would still be delivered over this connection.
    String previousUser = bindUserMap.getKey(channelId);
    if (previousUser != null && !previousUser.equals(userId)) {
      unbindIfCurrent(previousUser, channelId);
    }
    // Remove the user's old binding before adding the new one so the reverse (channel -> user)
    // entry for the old connection is dropped as well; a leftover reverse entry would let a
    // later removeChannel(oldConnection) resolve to this user.
    bindUserMap.remove(userId);
    bindUserMap.put(userId, channelId);
  }

  /**
   * Pushes a message to the connection the user is currently bound to. Does nothing when the user
   * is not bound; when the bound connection is no longer in the pool (it has been closed), the
   * stale binding is dropped and nothing is sent.
   *
   * @param userId the target user id; {@code null} is treated as "not bound"
   * @param data   the payload, sent as a text frame containing {@code data.toJson()}
   */
  public void sendToUser(String userId, WsBean data) {
    if (userId == null) {
      return;
    }
    ChannelId channelId = bindUserMap.get(userId);
    if (channelId == null) {
      return;
    }
    Channel channel = channels.find(channelId);
    if (channel == null) {
      // The channel was closed and left the group, but the user binding survived.
      unbindIfCurrent(userId, channelId);
      return;
    }
    channel.writeAndFlush(new TextWebSocketFrame(data.toJson()));
  }

  /**
   * Removes the binding of {@code userId} only if it still points at {@code channelId}, so a
   * concurrent re-bind to a different connection is never undone.
   *
   * @param userId    the user whose binding may be removed
   * @param channelId the connection id the binding is expected to reference
   */
  private synchronized void unbindIfCurrent(String userId, ChannelId channelId) {
    if (channelId.equals(bindUserMap.get(userId))) {
      bindUserMap.remove(userId);
    }
  }
}
